"""
#
# Flow stability for dynamic community detection https://arxiv.org/abs/2101.06131v2
#
# Copyright (C) 2021 Alexandre Bovet <alexandre.bovet@maths.ox.ac.uk>
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.


Plot the results of the clustering of the asymmetric synthetic example (Fig. 2).
    

"""
import sys
import os
# Append the directory one level above this script to the import path so
# that modules living there (presumably TemporalNetwork / TemporalStability,
# imported below) can be found.
PACKAGE_PARENT = '..'
SCRIPT_DIR = os.path.dirname(
    os.path.realpath(
        os.path.join(os.getcwd(), os.path.expanduser(__file__))
    )
)
sys.path.append(
    os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT))
)


import glob
import numpy as np
import pandas as pd
from TemporalNetwork import ContTempNetwork
from TemporalStability import norm_mutual_information, Clustering, FlowClustering, static_clustering
from copy import deepcopy

import matplotlib.pyplot as plt
plt.style.use('alex_paper')

import matplotlib as mpl



from palettable import cmocean, colorbrewer


mpl.rcParams['lines.linewidth'] = 1.7

# Colormaps for the matrix panels: ``divcmap`` (diverging PuOr) for the
# signed matrices, which are drawn with a range symmetric about zero, and
# ``seqcmap`` (the PuOr colours from the centre onward) for the transition
# matrix. Earlier experiments (mpl PuOr_r / Oranges, cmocean Curl_20) were
# assigned here and immediately overwritten; only these final ColorBrewer
# definitions were ever used, so only they are kept.
_puor_colors = colorbrewer.diverging.PuOr_11_r.mpl_colors
divcmap = mpl.colors.LinearSegmentedColormap.from_list('div', _puor_colors)
seqcmap = mpl.colors.LinearSegmentedColormap.from_list('seq', _puor_colors[5:])


# Input data and output figure locations (relative paths, resolved against
# the current working directory).
datadir = '../paper_data/synthtempnet/paper_example/'
lapl_intertransmatdir = os.path.join(datadir, 'lapl_intertransmat')
clusterdir = os.path.join(datadir, 'clustersI')

figdir = '../figures/paper_figures'

# All data files of this example share this prefix.
file_prefix = 'synthtemp_paper_example'
net_file = os.path.join(datadir, file_prefix)
sim_param_file = os.path.join(datadir, file_prefix + '_sim_param.pickle')

# Guard: this file appears to be a collection of "#%%" cells meant to be run
# one at a time in an interactive session. Executing it top to bottom stops
# here with an explanatory message instead of a bare, message-less Exception.
# RuntimeError is a subclass of Exception, so existing handlers still match.
raise RuntimeError(
    'This script is organised in "#%%" cells; run the cells below '
    'individually in an interactive session instead of executing the '
    'whole file.'
)

#%%

transfiles = os.listdir(lapl_intertransmatdir)

# Recover, from the '_'-separated file names, the interval indices (field 5,
# after a 3-character prefix) and the tau_w values (any field starting with
# 'w', after that letter). Each name is split only once.
_name_fields = [os.path.splitext(fname)[0].split('_') for fname in transfiles]
intervals = {int(fields[5][3:]) for fields in _name_fields}
tau_ws = {float(field[1:])
          for fields in _name_fields
          for field in fields
          if field.startswith('w')}
            
            
            
#%%% network file

# Load the temporal network and the simulation-parameter table.
net = ContTempNetwork.load(net_file)
sim_param = pd.read_pickle(sim_param_file)
net._compute_time_grid()

sim_end = sim_param['t_end']
num_intervals = len(intervals)

# num_intervals + 1 equally spaced boundaries covering [t_start, t_end].
int_times = np.linspace(sim_param['t_start'], sim_end, num_intervals + 1)

#%% load inter_Ts

inter_T_files = sorted(fname for fname in transfiles if 'inter_trans_mat' in fname)

# For every tau_w, concatenate (in sorted file-name order) the 'inter_T'
# matrix lists stored on disk. Entries are keyed by 1/tau_w.
inter_Ts = {}
for tau_w in tau_ws:
    print(tau_w)
    rate = 1/tau_w
    tag = f'tau_w{tau_w:.3e}'
    matching = sorted(fname for fname in inter_T_files if tag in fname)
    collected = inter_Ts[rate] = []
    for fname in matching:
        loaded = ContTempNetwork.load_inter_T(os.path.join(lapl_intertransmatdir, fname))
        collected.extend(loaded['inter_T'][rate])
        


#%% transmats

# Start after the first delta 1 period to have a symmetric evolution.
start_time, start_k = net._get_closest_time(40)


def _chain_product(mats, size):
    """Return the dense ordered product ``I @ M_0 @ M_1 @ ...`` of sparse ``mats``."""
    prod = np.eye(size)
    for mat in mats:
        prod = prod @ mat.toarray()
    return prod


# Forward product over the inter_T matrices from start_k on.
T = dict()
for tau_w in tau_ws:
    print(tau_w)
    T[tau_w] = _chain_product(inter_Ts[1/tau_w][start_k:], net.num_nodes)

# Same matrices multiplied in reverse order.
Trev = dict()
for tau_w in tau_ws:
    print(tau_w)
    Trev[tau_w] = _chain_product(inter_Ts[1/tau_w][start_k:][::-1], net.num_nodes)

# Rescale each row so that it sums to one.
for tau_w in tau_ws:
    T[tau_w] = T[tau_w] / T[tau_w].sum(1)[:, np.newaxis]
    Trev[tau_w] = Trev[tau_w] / Trev[tau_w].sum(1)[:, np.newaxis]
    

#%% load transition matrices

# Load precomputed forward (T) and reversed (Trev) transition matrices; this
# replaces the T, Trev and start_k computed in the previous cell.
start_k = 1907
res = pd.read_pickle(
    os.path.join(datadir, 'trans_mats_start_k{}.pickle'.format(start_k)))
# Alternative input without the start index in its name: 'trans_mats.pickle'.
T, Trev = res['T'], res['Trev']

#%%

tau_w = 10

# Flow-clustering objects for the forward and the reversed transition matrix,
# and their ``_S`` matrices.
fc, fcrev = FlowClustering(T=T[tau_w]), FlowClustering(T=Trev[tau_w])
S, Srev = fc._S, fcrev._S

# Static reference: symmetrised static adjacency matrix of the network.
A = net.compute_static_adjacency_matrix().toarray()
A = (A + A.T)/2

stat_clust = static_clustering(A, t=1, discrete_time_rw=True)
B = stat_clust._S


#%% flow integral stab
tau_w = 10

# Uniform distribution over the nodes.
p1 = np.ones(net.num_nodes)/net.num_nodes

_integral_dir = os.path.join(datadir, 'integralgrid')
_tau_tag = file_prefix + f'_tau_w{tau_w:.3e}'
Iforw_file = os.path.join(_integral_dir, _tau_tag + '_PT_000003_to_000036.pickle')
Iback_file = os.path.join(_integral_dir, _tau_tag + '_PT_000036_to_000003.pickle')
Iforw_res = pd.read_pickle(Iforw_file)
Iback_res = pd.read_pickle(Iback_file)


def _time_averaged_covariance(integral_res, p):
    """Return ``diag(p) @ ITPT @ diag(p) / integration_time - outer(p, p)``."""
    P = np.diag(p)
    return (1/integral_res['integration_time']) * P @ integral_res['ITPT'] @ P - np.outer(p, p)


Iforw = _time_averaged_covariance(Iforw_res, p1)
Iback = _time_averaged_covariance(Iback_res, p1)




#%% make data for flow diagram

# tau_w = 120

_clusters_stem = os.path.join(datadir, 'clustersI',
                              'clusters_' + file_prefix + f'_tau_w{tau_w:.3e}')
clustforw_file = _clusters_stem + '_PT_000003_to_000036.pickle'
clustback_file = _clusters_stem + '_PT_000036_to_000003.pickle'
clustforw_res = pd.read_pickle(clustforw_file)
clustback_res = pd.read_pickle(clustback_file)

source_comms = clustforw_res['best_cluster_sym']
target_comms = clustback_res['best_cluster_sym']

# Ground-truth groups: group i holds nodes [i*n, (i+1)*n), n = n_per_group.
_n = sim_param['n_per_group']
class_dict = {i: set(range(i*_n, (i+1)*_n)) for i in range(sim_param['n_groups'])}

# One row per (group, forward community, backward community) triple sharing
# at least one node; 'value' is the number of shared nodes.
flows = []
for clas, clas_set in class_dict.items():
    for s, comm_s in enumerate(source_comms):
        in_class_and_s = clas_set.intersection(comm_s)
        for t, comm_t in enumerate(target_comms):
            n_shared = len(in_class_and_s.intersection(comm_t))
            if n_shared:
                flows.append({'source': s, 'target': t, 'type': clas, 'value': n_shared})

df_flows = pd.DataFrame.from_dict(flows)

# Shift all ids so that they start from 1 instead of 0.
df_flows.source += 1
df_flows.target += 1
df_flows.type += 1

import floweaver as flo

# Colours assigned, in sorted order, to the flow types.
color_list = ["#613D94", "#D0730F", "#359455"]

sources_list = np.array(df_flows.source.unique().tolist())
targets_list = np.array(df_flows.target.unique().tolist())
class_list = sorted(df_flows.type.unique().tolist())

# Partitions: split each side by its id along the "process" dimension, and
# colour the flows by the "type" column of the flows table.
source_part = flo.Partition.Simple('process', sources_list.tolist())
target_part = flo.Partition.Simple('process', targets_list.tolist())
class_flow_part = flo.Partition.Simple('type', class_list)

nodes = {
    'forward': flo.ProcessGroup(sources_list.tolist(),
                                partition=source_part, title='Forward'),
    'backward': flo.ProcessGroup(targets_list.tolist(),
                                 partition=target_part, title='Backward'),
}

# "forward" in the left column, "backward" in the right column.
ordering = [['forward'], ['backward']]
bundles = [flo.Bundle('forward', 'backward')]

palette = dict(zip(class_list, color_list))

sdd = flo.SankeyDefinition(nodes, bundles, ordering,
                           flow_partition=class_flow_part)
w = flo.weave(sdd, df_flows, palette=palette)

#%%
import json

# Write the weaved Sankey data to the figure directory as JSON.
_flow_json_path = os.path.join(figdir, 'synthnet_flow.json')
with open(_flow_json_path, 'w') as out:
    json.dump(w.to_json(), out)
    
#%% clustering t1 vs t2

# The three planted groups of nine nodes each.
clust1, clust2, clust3 = (set(range(lo, lo + 9)) for lo in (0, 9, 18))

# Reference partitions: all three groups separate, or two of them merged.
part123 = [clust1, clust2, clust3]
part12 = [clust1 | clust2, clust3]
part13 = [clust1 | clust3, clust2]
part23 = [clust2 | clust3, clust1]

t2s = []
t2s_rev = []

# Empty result template (one list per statistic) and independent copies for
# the forward ('sym') and time-reversed analyses.
res = {key: [] for key in ('avg_nclust', 'nvarinf', 'mostc_clust',
                           'nmi_123', 'nmi_12', 'nmi_13', 'nmi_23',
                           'max_nmis', 'best_cluster', 'best_stab')}

clust_res = {'sym': deepcopy(res)}
clust_res_rev = deepcopy(clust_res)


def analyse_res(res, typ):
    """Compare the best clustering stored in ``res`` with the reference partitions.

    Parameters
    ----------
    res : dict-like
        Clustering results; must contain the keys ``'avg_nclust_<typ>'``,
        ``'nvarinf_<typ>'`` and ``'best_cluster_<typ>'`` (presumably a list
        of node sets).
    typ : str
        Suffix selecting which result to analyse (e.g. ``'sym'``).

    Returns
    -------
    tuple
        ``(avg_nclust, nvarinf, best_cluster, max_nmis, nmi_123, nmi_12,
        nmi_13, nmi_23, nmi_eq1, clust_pairs)``. Here ``max_nmis`` is the index
        (0..3, order 123/12/13/23) of the reference partition with the largest
        NMI, and ``nmi_eq1`` holds the indices whose NMI is (numerically) 1.
        ``clust_pairs`` holds the indices ``k`` of the recovered structure:
        ``k=0`` when the NMI with ``part123`` is 1, and ``k=1/2/3`` when the
        merged cluster ``clust1|clust2`` / ``clust1|clust3`` /
        ``clust2|clust3`` is one of the best clusters. Both index arrays come
        from ``np.argwhere`` (shape ``(m, 1)``).
    """
    best = res['best_cluster_' + typ]

    # NMI against the reference partitions, in the order 123, 12, 13, 23.
    nmis = np.array([norm_mutual_information(best, part)
                     for part in (part123, part12, part13, part23)])
    max_nmis = np.argmax(nmis)

    # NMI is a float: compare with a tolerance instead of exact equality so a
    # rounding error on an identical partition is still recognised.
    nmi_is_one = np.isclose(nmis, 1.0)

    # Which reference structure (exact 3-group split, or one merged pair) is
    # present in the best clustering. nmis[0] is reused instead of recomputing
    # the NMI with part123.
    has_clust_pairs = [nmi_is_one[0],
                       clust1.union(clust2) in best,
                       clust1.union(clust3) in best,
                       clust2.union(clust3) in best]

    nmi_eq1 = np.argwhere(nmi_is_one)
    clust_pairs = np.argwhere(has_clust_pairs)

    return (res['avg_nclust_' + typ], res['nvarinf_' + typ], best,
            max_nmis, nmis[0], nmis[1], nmis[2], nmis[3], nmi_eq1, clust_pairs)
            
    
        
#%%

# For every (start, stop) pair of interval boundaries, record which reference
# structure the best 'sym' clustering recovers (0: [1],[2],[3]; 1: [1,2],[3];
# 2: [1,3],[2]; 3: [1],[2,3]), or NaN when none is recovered.
clust_t1t2_sym = dict()

starts = list(range(0, 37, 3))
stops = list(range(0, 37, 3))

for int_start in starts[1:]:
    row = clust_t1t2_sym[int_times[int_start]] = dict()

    for int_stop in stops[1:]:
        if int_start == int_stop:
            continue
        res = pd.read_pickle(os.path.join(
            clusterdir,
            'clusters_' + file_prefix + '_tau_w'
            + '{0:.3e}_PT_{1:06d}_to_{2:06d}.pickle'.format(tau_w, int_start, int_stop)))

        # Last element of analyse_res: indices of the recovered reference
        # structures (np.argwhere output). The merged clusters overlap, so at
        # most one of them can appear in a partition.
        recovered = analyse_res(res, 'sym')[-1].ravel()

        # float() on a 1-element array is deprecated since NumPy 1.25, so the
        # scalar is extracted explicitly. NaN marks "nothing recovered".
        row[int_times[int_stop]] = float(recovered[0]) if recovered.size else np.nan
            
#%%

def _shaded_windows(dt1, dt2, n_periods):
    """Return the ``(start, end)`` bounds of the shaded windows.

    Window ``k`` (``0 <= k < n_periods``) covers
    ``[k*(dt1 + dt2), k*(dt1 + dt2) + dt2]``: the first ``dt2`` time units of
    the k-th period of length ``dt1 + dt2``.
    """
    period = dt1 + dt2
    return [(k*period, k*period + dt2) for k in range(n_periods)]


def add_vspan(ax, dt1=None, dt2=None, alpha=0.25, n_periods=3,
              color="#613D94"):
    """Shade ``n_periods`` vertical bands on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
    dt1, dt2 : float, optional
        Period components. They default to ``sim_param['deltat1']`` and
        ``sim_param['deltat2']``, looked up when the function is called (not
        frozen at definition time), so re-loading ``sim_param`` is picked up.
    alpha : float
        Band transparency.
    n_periods : int
        Number of bands; 3 reproduces the original figures.
    color : str
        Band colour.
    """
    if dt1 is None:
        dt1 = sim_param['deltat1']
    if dt2 is None:
        dt2 = sim_param['deltat2']
    for start, end in _shaded_windows(dt1, dt2, n_periods):
        ax.axvspan(start, end, alpha=alpha, color=color)


def add_hspan(ax, dt1=None, dt2=None, alpha=0.25, n_periods=3,
              color="#613D94"):
    """Shade ``n_periods`` horizontal bands on ``ax`` (same windows as ``add_vspan``)."""
    if dt1 is None:
        dt1 = sim_param['deltat1']
    if dt2 is None:
        dt2 = sim_param['deltat2']
    for start, end in _shaded_windows(dt1, dt2, n_periods):
        ax.axhspan(start, end, alpha=alpha, color=color)
    
# Outer keys (start times) become columns, inner keys (stop times) the index.
df = pd.DataFrame.from_dict(clust_t1t2_sym)

fig, ax = plt.subplots()

import matplotlib.colors as colors

# One colour per recovered structure index 0..3 (see the legend below).
cmap = colors.ListedColormap(['#D13204', "#613D94", "#D0730F", "#359455",])
boundaries = [-0.5, 0.5, 1.5, 2.5, 3.5]
norm = colors.BoundaryNorm(boundaries, cmap.N, clip=True)

# Both axes are shifted by deltat1.
shift = sim_param['deltat1']
for t_start in df.columns:
    ax.scatter(df.index - shift,
               [t_start - shift]*df.index.size,
               c=df[t_start],
               cmap=cmap,
               vmin=0,
               vmax=cmap.N)

ax.axis('square')
ax.set_xlabel(r'$t_2$')
ax.set_ylabel(r'$t_1$')
ax.set_ylim(-20, 460)
ax.set_xlim(-20, 460)

add_vspan(ax, alpha=0.1)
add_hspan(ax, alpha=0.1)

# Proxy artists for the custom legend, one per colour of cmap.
c123, c12, c13, c23 = [plt.Line2D((0, 0), (0, 0), color=cmap(i), marker='o', linestyle='')
                       for i in range(4)]

plt.legend([c123, c12, c13, c23], ['[1],[2],[3]', '[1,2],[3]', '[1,3],[2]', '[1],[2,3]'],
           bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)


#%% final plot for paper

from mpl_toolkits.axes_grid1.inset_locator import inset_axes

mpl.rcParams['lines.linewidth'] = 1.7
mpl.rcParams['font.size'] = 12

fig, axes = plt.subplots(2, 4, figsize=(12, 5.5))

# Image handle per panel; entries left at 0 get no colorbar later on.
ms = np.zeros_like(axes)

ms[0][1] = axes[0][1].matshow(T[tau_w], cmap=seqcmap)
axes[0][1].set_title(r'$T(t_1,t_2)$',)


def _matshow_symmetric(target_ax, mat):
    """matshow ``mat`` with ``divcmap`` on a colour range symmetric about 0."""
    vlim = np.abs(mat).max()
    return target_ax.matshow(mat, cmap=divcmap, vmin=-vlim, vmax=vlim)


for (row, col), mat, title in [((0, 2), B, r'$B(t_1,t_2)$'),
                               ((0, 3), S, r'$P_1T(t_1,t_2)-p_1^Tp_2$'),
                               ((1, 0), Iforw, r'$\int_{t_1}^{t_2} S_{forw}(t_1,t)dt$'),
                               ((1, 1), Iback, r'$\int_{t_2}^{t_1} S_{back}(t_2,t)dt$')]:
    ms[row][col] = _matshow_symmetric(axes[row][col], mat)
    axes[row][col].set_title(title)





# Bottom-right panel: recovered structure per (start, end) pair, both axes
# shifted by deltat1.
shift = sim_param['deltat1']
panel = axes[1][3]
for t_start in df.columns:
    panel.scatter(df.index - shift,
                  [t_start - shift]*df.index.size,
                  c=df[t_start],
                  cmap=cmap,
                  vmin=0,
                  vmax=cmap.N)

panel.axis('square')
panel.set_xlabel(r'$t_{end}$')
panel.set_ylabel(r'$t_{start}$')
panel.set_ylim(-20, 460)
panel.set_xlim(-20, 460)

add_vspan(panel, alpha=0.1)
add_hspan(panel, alpha=0.1)

# Proxy artists for the custom legend, one per colour of cmap.
c123, c12, c13, c23 = [plt.Line2D((0, 0), (0, 0), color=cmap(i), marker='o', linestyle='')
                       for i in range(4)]

plt.legend([c123, c12, c13, c23], ['[1],[2],[3]', '[1,2],[3]', '[1,3],[2]', '[1],[2,3]'],
           bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

# Bold panel letters A..H (LaTeX markup, rendered with usetex).
letters = np.array([r'\textbf{%s}' % letter for letter in 'ABCDEFGH'])
    

import types

def bottom_offset(self, bboxes, bboxes2):
    """Place the axis offset text just below the bottom edge of its axes.

    Meant to be bound as ``_update_offset_text_position`` on an axis (see
    ``register_bottom_offset``); ``bboxes`` and ``bboxes2`` are accepted for
    signature compatibility and ignored.
    """
    # OFFSETTEXTPAD is in points; dpi / 72 converts it to display pixels.
    pad_px = self.OFFSETTEXTPAD * self.figure.dpi / 72.0
    self.offsetText.set(va="top", ha="left")
    self.offsetText.set_position((0, self.axes.bbox.ymin - pad_px))

def register_bottom_offset(axis, func):
    """Bind ``func`` to ``axis`` as its ``_update_offset_text_position`` method."""
    bound = types.MethodType(func, axis)
    axis._update_offset_text_position = bound




# Flattened panel indices that get a letter. Those with an image stored in
# ``ms`` also get a colorbar; index 7 (the scatter panel) has none.
ind_to_plot = np.array([1, 2, 3, 4, 5, 7])

for m, ax, letter in zip(ms.flatten()[ind_to_plot], axes.flatten()[ind_to_plot],
                         letters[ind_to_plot]):
    if m != 0:
        # Colorbar in a thin inset axes placed just right of the panel.
        axins = inset_axes(ax,
                           width="5%",  # width = 5% of parent_bbox width
                           height="100%",
                           loc='lower left',
                           bbox_to_anchor=(1.05, 0., 1, 1),
                           bbox_transform=ax.transAxes,
                           borderpad=0,
                           )

        cb = plt.colorbar(m, cax=axins)
        ax.xaxis.tick_bottom()
    # The property is ``usetex`` (lower case). Matplotlib deprecated the
    # lower-casing of mixed-case property names such as ``useTex`` in 3.3 and
    # later removed it, so the old spelling raises AttributeError.
    ax.text(-0.25, 1.1, letter, transform=ax.transAxes, usetex=True)
    # To move the colorbar offset text below the axes, enable:
    # register_bottom_offset(cb.ax.yaxis, bottom_offset)
    # cb.update_ticks()

axes[0][0].axis('off')
axes[1][2].axis('off')

plt.tight_layout()

#%%
# Save the figure as a 600-dpi PNG and as a PDF sharing the same file stem.
_fig_stem = os.path.join(figdir, f'paper_example_trans_covar_matrices_tau_w{tau_w:.2e}')
plt.savefig(_fig_stem + '.png', dpi=600)
plt.savefig(_fig_stem + '.pdf')
#
